{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to XGBoost Spark with GPU\n",
    "\n",
    "The goal of this notebook is to show how to train an XGBoost model on GPUs with the Spark RAPIDS XGBoost library. The dataset used with this notebook is derived from Fannie Mae’s Single-Family Loan Performance Data with all rights reserved by Fannie Mae. This notebook uses XGBoost to train a 12-month mortgage loan delinquency prediction model.\n",
    "\n",
    "## Load libraries\n",
    "First, load the common libraries used by both the GPU and CPU versions of XGBoost."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Intitializing Scala interpreter ..."
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Spark Web UI available at http://liyuan-gcp-init-v2206-m:8088/proxy/application_1658995519106_0006\n",
       "SparkContext available as 'sc' (version = 3.1.3, master = yarn, app id = application_1658995519106_0006)\n",
       "SparkSession available as 'spark'\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}\n",
       "import org.apache.spark.sql.SparkSession\n",
       "import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\n",
       "import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}\n"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}\n",
    "import org.apache.spark.sql.SparkSession\n",
    "import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\n",
    "import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition, the CPU version requires some extra libraries, such as:\n",
    "\n",
    "```scala\n",
    "import org.apache.spark.ml.feature.VectorAssembler\n",
    "import org.apache.spark.sql.DataFrame\n",
    "import org.apache.spark.sql.functions._\n",
    "import org.apache.spark.sql.types.FloatType\n",
    "```"
   ]
  },
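  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of why the CPU version needs these imports: it expects the features packed into a single vector column, so the feature columns are typically cast and assembled with `VectorAssembler` before training. The helper name `vectorize` and the output column name `features` are assumptions chosen for illustration; they are not defined elsewhere in this notebook.\n",
    "\n",
    "```scala\n",
    "import org.apache.spark.ml.feature.VectorAssembler\n",
    "import org.apache.spark.sql.DataFrame\n",
    "import org.apache.spark.sql.functions.col\n",
    "import org.apache.spark.sql.types.FloatType\n",
    "\n",
    "// Hypothetical helper: cast every feature column to float, then pack them\n",
    "// into one vector column named \"features\", keeping the label alongside it.\n",
    "def vectorize(df: DataFrame, featureCols: Array[String], labelCol: String): DataFrame = {\n",
    "  val casted = featureCols.foldLeft(df) { (acc, name) =>\n",
    "    acc.withColumn(name, col(name).cast(FloatType))\n",
    "  }\n",
    "  new VectorAssembler()\n",
    "    .setInputCols(featureCols)\n",
    "    .setOutputCol(\"features\")\n",
    "    .transform(casted)\n",
    "    .select(col(\"features\"), col(labelCol))\n",
    "}\n",
    "```\n",
    "\n",
    "The GPU version in this notebook skips this step and passes the feature column names directly to the classifier."
   ]
  },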
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set the dataset path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dataRoot: String = gs://rapids-test/yuanli-tools-eventlog-temp/mortgage-fannieMae/test20220728/output\n",
       "trainPath: String = gs://rapids-test/yuanli-tools-eventlog-temp/mortgage-fannieMae/test20220728/output/train/\n",
       "evalPath: String = gs://rapids-test/yuanli-tools-eventlog-temp/mortgage-fannieMae/test20220728/output/eval/\n",
       "transPath: String = gs://rapids-test/yuanli-tools-eventlog-temp/mortgage-fannieMae/test20220728/output/eval/\n"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// Update these to your real paths! The input data files are the output of the mortgage-etl jobs.\n",
    "val dataRoot = sys.env.getOrElse(\"DATA_ROOT\", \"gs://your-bucket/your-ETL-output-paths\")\n",
    "val trainPath = dataRoot + \"/train/\"\n",
    "val evalPath  = dataRoot + \"/eval/\"\n",
    "val transPath = dataRoot + \"/eval/\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the schema and parameters\n",
    "The mortgage data has 28 columns: 27 features and 1 label. \"delinquency_12\" is the label column. The schema documents the columns of the dataset, and the feature names are derived from it by dropping the label.\n",
    "\n",
    "The next block also defines some key parameters used in the XGBoost training process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "labelColName: String = delinquency_12\n",
       "schema: org.apache.spark.sql.types.StructType = StructType(StructField(orig_channel,DoubleType,true), StructField(first_home_buyer,DoubleType,true), StructField(loan_purpose,DoubleType,true), StructField(property_type,DoubleType,true), StructField(occupancy_status,DoubleType,true), StructField(property_state,DoubleType,true), StructField(product_type,DoubleType,true), StructField(relocation_mortgage_indicator,DoubleType,true), StructField(seller_name,DoubleType,true), StructField(mod_flag,DoubleType,true), StructField(orig_interest_rate,DoubleType,true), StructField(orig_upb,DoubleType,true), StructField(orig_loan_term,IntegerType,true), StructField(orig_ltv,DoubleType,true), StructField(orig_cltv,DoubleType,true), StructField(num_borrowers,DoubleTy...\n"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val labelColName = \"delinquency_12\"\n",
    "val schema = StructType(List(\n",
    "  StructField(\"orig_channel\", DoubleType),\n",
    "  StructField(\"first_home_buyer\", DoubleType),\n",
    "  StructField(\"loan_purpose\", DoubleType),\n",
    "  StructField(\"property_type\", DoubleType),\n",
    "  StructField(\"occupancy_status\", DoubleType),\n",
    "  StructField(\"property_state\", DoubleType),\n",
    "  StructField(\"product_type\", DoubleType),\n",
    "  StructField(\"relocation_mortgage_indicator\", DoubleType),\n",
    "  StructField(\"seller_name\", DoubleType),\n",
    "  StructField(\"mod_flag\", DoubleType),\n",
    "  StructField(\"orig_interest_rate\", DoubleType),\n",
    "  StructField(\"orig_upb\", DoubleType),\n",
    "  StructField(\"orig_loan_term\", IntegerType),\n",
    "  StructField(\"orig_ltv\", DoubleType),\n",
    "  StructField(\"orig_cltv\", DoubleType),\n",
    "  StructField(\"num_borrowers\", DoubleType),\n",
    "  StructField(\"dti\", DoubleType),\n",
    "  StructField(\"borrower_credit_score\", DoubleType),\n",
    "  StructField(\"num_units\", IntegerType),\n",
    "  StructField(\"zip\", IntegerType),\n",
    "  StructField(\"mortgage_insurance_percent\", DoubleType),\n",
    "  StructField(\"current_loan_delinquency_status\", IntegerType),\n",
    "  StructField(\"current_actual_upb\", DoubleType),\n",
    "  StructField(\"interest_rate\", DoubleType),\n",
    "  StructField(\"loan_age\", DoubleType),\n",
    "  StructField(\"msa\", DoubleType),\n",
    "  StructField(\"non_interest_bearing_upb\", DoubleType),\n",
    "  StructField(labelColName, IntegerType)))\n",
    "\n",
    "val featureNames = schema.filter(_.name != labelColName).map(_.name).toArray\n",
    "\n",
    "val commParamMap = Map(\n",
    "  \"objective\" -> \"binary:logistic\",\n",
    "  \"num_round\" -> 100)"
   ]
  },
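  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The way `featureNames` is derived can be checked on a smaller example. The sketch below is illustrative only: `demoSchema`, `demoLabel`, and `demoFeatures` are hypothetical names, and the three-column schema is a stand-in for the full one.\n",
    "\n",
    "```scala\n",
    "import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}\n",
    "\n",
    "// Hypothetical three-column schema: two features and one label.\n",
    "val demoLabel = \"delinquency_12\"\n",
    "val demoSchema = StructType(List(\n",
    "  StructField(\"orig_interest_rate\", DoubleType),\n",
    "  StructField(\"orig_loan_term\", IntegerType),\n",
    "  StructField(demoLabel, IntegerType)))\n",
    "\n",
    "// Every column except the label is treated as a feature.\n",
    "val demoFeatures = demoSchema.fieldNames.filter(_ != demoLabel)\n",
    "\n",
    "// prints \"3 columns, 2 features\"\n",
    "println(s\"${demoSchema.length} columns, ${demoFeatures.length} features\")\n",
    "```\n",
    "\n",
    "The same two expressions, applied to `schema` and `featureNames`, give the column and feature counts of the full dataset."
   ]
  },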
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a new Spark session and load data\n",
    "\n",
    "A Spark session is required for all of the following Spark operations.\n",
    "\n",
    "NOTE: in this notebook, the dependency jars were loaded when the Toree kernel was installed. Alternatively, the jars can be loaded into the notebook with the [%AddJar magic](https://toree.incubator.apache.org/docs/current/user/faq/). However, `%AddJar` has one restriction: the jars are only available if `%AddJar` is called right after a new Spark session is created. Do it as below:\n",
    "\n",
    "```scala\n",
    "import org.apache.spark.sql.SparkSession\n",
    "val spark = SparkSession.builder().appName(\"mortgage-GPU\").getOrCreate\n",
    "%AddJar file:/data/libs/rapids-4-spark-XXX.jar\n",
    "%AddJar file:/data/libs/xgboost4j-spark-gpu_2.12-XXX.jar\n",
    "%AddJar file:/data/libs/xgboost4j-gpu_2.12-XXX.jar\n",
    "// ...\n",
    "```\n",
    "\n",
    "##### Please note that the jar \"rapids-4-spark-XXX.jar\" is only needed for the GPU version; do not add it to the dependency list for the CPU version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "sparkSession: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@5d24fee\n",
       "reader: org.apache.spark.sql.DataFrameReader = org.apache.spark.sql.DataFrameReader@1926763c\n"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// Build the spark session and data reader as usual\n",
    "val sparkSession = SparkSession.builder.appName(\"mortgage-gpu\").getOrCreate\n",
    "val reader = sparkSession.read"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "trainSet: org.apache.spark.sql.DataFrame = [orig_channel: int, first_home_buyer: int ... 26 more fields]\n",
       "evalSet: org.apache.spark.sql.DataFrame = [orig_channel: int, first_home_buyer: int ... 26 more fields]\n",
       "transSet: org.apache.spark.sql.DataFrame = [orig_channel: int, first_home_buyer: int ... 26 more fields]\n"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val trainSet = reader.parquet(trainPath)\n",
    "val evalSet  = reader.parquet(evalPath)\n",
    "val transSet = reader.parquet(transPath)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set XGBoost parameters and build an XGBoostClassifier\n",
    "\n",
    "For the CPU version, it is recommended to set `num_workers` to the number of CPU cores, while for the GPU version it should be set to the number of GPUs in the Spark cluster.\n",
    "\n",
    "The `tree_method` also differs between the CPU and GPU versions. Currently, only \"gpu_hist\" is supported for training on GPUs.\n",
    "\n",
    "```scala\n",
    "// CPU version parameters, shown for comparison\n",
    "  \"num_workers\" -> 12,\n",
    "  \"tree_method\" -> \"hist\",\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "xgbParamFinal: scala.collection.immutable.Map[String,Any] = Map(objective -> binary:logistic, num_round -> 100, tree_method -> gpu_hist, num_workers -> 1)\n"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val xgbParamFinal = commParamMap ++ Map(\"tree_method\" -> \"gpu_hist\", \"num_workers\" -> 1)"
   ]
  },
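  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The CPU and GPU settings described above could be kept in one place. This is a minimal sketch, assuming a hypothetical helper `buildParams`; the names `demoCommon`, `gpuParams`, and `cpuParams` and the core and GPU counts are placeholders to replace with the values for your cluster.\n",
    "\n",
    "```scala\n",
    "// Hypothetical helper: pick the device-specific parameters and merge them\n",
    "// with the shared ones. The caller supplies the cluster's core and GPU counts.\n",
    "def buildParams(common: Map[String, Any], useGpu: Boolean,\n",
    "                numCpuCores: Int, numGpus: Int): Map[String, Any] = {\n",
    "  val deviceParams: Map[String, Any] =\n",
    "    if (useGpu) Map(\"tree_method\" -> \"gpu_hist\", \"num_workers\" -> numGpus)\n",
    "    else Map(\"tree_method\" -> \"hist\", \"num_workers\" -> numCpuCores)\n",
    "  common ++ deviceParams\n",
    "}\n",
    "\n",
    "val demoCommon = Map[String, Any](\"objective\" -> \"binary:logistic\", \"num_round\" -> 100)\n",
    "val gpuParams = buildParams(demoCommon, useGpu = true, numCpuCores = 12, numGpus = 1)\n",
    "val cpuParams = buildParams(demoCommon, useGpu = false, numCpuCores = 12, numGpus = 1)\n",
    "```\n",
    "\n",
    "With these placeholder counts, `gpuParams` matches the map built in the previous cell, and `cpuParams` matches the CPU values shown above."
   ]
  },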
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "xgbClassifier: ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier = xgbc_a9ea95ac6122\n"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val xgbClassifier = new XGBoostClassifier(xgbParamFinal)\n",
    "      .setLabelCol(labelColName)\n",
    "      .setFeaturesCol(featureNames)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmark and train\n",
    "The object `Benchmark` is used to measure the elapsed time of an operation.\n",
    "\n",
    "Training with evaluation sets is also supported in two ways, the same as in the CPU version:\n",
    "\n",
    "* Call API `setEvalSets` after initializing an XGBoostClassifier\n",
    "\n",
    "```scala\n",
    "xgbClassifier.setEvalSets(Map(\"eval\" -> evalSet))\n",
    "\n",
    "```\n",
    "\n",
    "* Use parameter `eval_sets` when initializing an XGBoostClassifier\n",
    "\n",
    "```scala\n",
    "val paramMapWithEval = xgbParamFinal + (\"eval_sets\" -> Map(\"eval\" -> evalSet))\n",
    "val xgbClassifierWithEval = new XGBoostClassifier(paramMapWithEval)\n",
    "```\n",
    "\n",
    "This notebook uses the `setEvalSets` API to set the evaluation sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "res0: xgbClassifier.type = xgbc_a9ea95ac6122\n"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xgbClassifier.setEvalSets(Map(\"eval\" -> evalSet))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defined object Benchmark\n"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "object Benchmark {\n",
    "  def time[R](phase: String)(block: => R): (R, Float) = {\n",
    "    val t0 = System.currentTimeMillis\n",
    "    val result = block // call-by-name\n",
    "    val t1 = System.currentTimeMillis\n",
    "    println(\"Elapsed time [\" + phase + \"]: \" + ((t1 - t0).toFloat / 1000) + \"s\")\n",
    "    (result, (t1 - t0).toFloat / 1000)\n",
    "  }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "------ Training ------\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.128.0.22, DMLC_TRACKER_PORT=38839, DMLC_NUM_WORKER=1}\n",
      "Elapsed time [train]: 32.351s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "xgbClassificationModel: ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel = xgbc_a9ea95ac6122\n"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// Start training\n",
    "println(\"\\n------ Training ------\")\n",
    "val (xgbClassificationModel, _) = Benchmark.time(\"train\") {\n",
    "  xgbClassifier.fit(trainSet)\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transformation and evaluation\n",
    "The next cell uses `transSet` to evaluate the model and prints some useful columns to show the prediction results. After that, `MulticlassClassificationEvaluator` is used to compute an overall score for the predictions. Note that the evaluator is left at its default metric, F1, so the value printed under the accuracy heading is an F1 score."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "------ Transforming ------\n",
      "Elapsed time [transform]: 15.303s\n",
      "+------------+--------------+--------------------+--------------------+----------+\n",
      "|orig_channel|delinquency_12|       rawPrediction|         probability|prediction|\n",
      "+------------+--------------+--------------------+--------------------+----------+\n",
      "|           1|             0|[6.49920845031738...|[0.99849763081874...|       0.0|\n",
      "|           1|             0|[6.35734319686889...|[0.99826903396751...|       0.0|\n",
      "|           1|             0|[4.97297477722168...|[0.99312506755813...|       0.0|\n",
      "|           1|             0|[6.82305383682251...|[0.99891279113944...|       0.0|\n",
      "|           1|             0|[9.16866493225097...|[0.99989575525250...|       0.0|\n",
      "|           1|             0|[4.66329050064086...|[0.99065283033996...|       0.0|\n",
      "|           1|             0|[5.91289663314819...|[0.99730295152403...|       0.0|\n",
      "|           1|             0|[6.6304030418396,...|[0.99868210789281...|       0.0|\n",
      "|           1|             0|[7.32355213165283...|[0.99934062076499...|       0.0|\n",
      "|           1|             0|[7.49035978317260...|[0.99944186967331...|       0.0|\n",
      "+------------+--------------+--------------------+--------------------+----------+\n",
      "only showing top 10 rows\n",
      "\n",
      "\n",
      "------Accuracy of Evaluation------\n",
      "0.9861303878329984\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "results: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [orig_channel: int, first_home_buyer: int ... 29 more fields]\n",
       "evaluator: org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = MulticlassClassificationEvaluator: uid=mcEval_ea35361b8ad3, metricName=f1, metricLabel=0.0, beta=1.0, eps=1.0E-15\n",
       "accuracy: Double = 0.9861303878329984\n"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "println(\"\\n------ Transforming ------\")\n",
    "val (results, _) = Benchmark.time(\"transform\") {\n",
    "  val ret = xgbClassificationModel.transform(transSet).cache()\n",
    "  ret.foreachPartition((_: Iterator[_]) => ())\n",
    "  ret\n",
    "}\n",
    "results.select(\"orig_channel\", labelColName,\"rawPrediction\",\"probability\",\"prediction\").show(10)\n",
    "\n",
    "println(\"\\n------Accuracy of Evaluation------\")\n",
    "val evaluator = new MulticlassClassificationEvaluator().setLabelCol(labelColName)\n",
    "val accuracy = evaluator.evaluate(results)\n",
    "println(accuracy)"
   ]
  },
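  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The evaluator in the cell above is created without a metric name, and its recorded output shows `metricName=f1`. If plain accuracy is wanted instead, the metric can be set explicitly. The sketch below assumes a hypothetical helper `evaluatorFor`; `setMetricName` is the standard Spark ML setter, and `accuracy` and `f1` are two of the metric names it accepts.\n",
    "\n",
    "```scala\n",
    "import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator\n",
    "\n",
    "// Hypothetical helper: build an evaluator for an explicitly named metric.\n",
    "def evaluatorFor(metric: String, labelCol: String): MulticlassClassificationEvaluator =\n",
    "  new MulticlassClassificationEvaluator()\n",
    "    .setLabelCol(labelCol)\n",
    "    .setMetricName(metric)\n",
    "\n",
    "val accuracyEvaluator = evaluatorFor(\"accuracy\", \"delinquency_12\")\n",
    "val f1Evaluator = evaluatorFor(\"f1\", \"delinquency_12\")\n",
    "\n",
    "// Usage, once a predictions DataFrame such as `results` is available:\n",
    "//   accuracyEvaluator.evaluate(results)\n",
    "```"
   ]
  },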
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the model to disk and load model\n",
    "Save the model to disk and then load it back into memory. After that, use the loaded model to make a new prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Elapsed time [transform2]: 0.306s\n",
      "+------------+----------------+------------+-------------+----------------+--------------+------------+-----------------------------+-----------+--------+------------------+--------+--------------+--------+---------+-------------+----+---------------------+---------+---+--------------------------+-------------------------------+------------------+-------------+--------+-------+------------------------+--------------+--------------------+--------------------+----------+\n",
      "|orig_channel|first_home_buyer|loan_purpose|property_type|occupancy_status|property_state|product_type|relocation_mortgage_indicator|seller_name|mod_flag|orig_interest_rate|orig_upb|orig_loan_term|orig_ltv|orig_cltv|num_borrowers| dti|borrower_credit_score|num_units|zip|mortgage_insurance_percent|current_loan_delinquency_status|current_actual_upb|interest_rate|loan_age|    msa|non_interest_bearing_upb|delinquency_12|       rawPrediction|         probability|prediction|\n",
      "+------------+----------------+------------+-------------+----------------+--------------+------------+-----------------------------+-----------+--------+------------------+--------+--------------+--------+---------+-------------+----+---------------------+---------+---+--------------------------+-------------------------------+------------------+-------------+--------+-------+------------------------+--------------+--------------------+--------------------+----------+\n",
      "|           1|               0|           1|            1|               1|             1|           1|                            1|          1|       1|              7.75| 61000.0|           180|    75.0|      0.0|          0.0| 0.0|                  0.0|        1|336|                       0.0|                              0|          30237.97|         7.75|    24.0|45300.0|                     0.0|             0|[6.49920845031738...|[0.99849763081874...|       0.0|\n",
      "|           1|               0|           1|            1|               1|             1|           1|                            1|          1|       1|              7.75| 61000.0|           180|    75.0|      0.0|          0.0| 0.0|                  0.0|        1|336|                       0.0|                              0|          45812.86|         7.75|     9.0|45300.0|                     0.0|             0|[6.35734319686889...|[0.99826903396751...|       0.0|\n",
      "|           1|               0|           1|            1|               1|             5|           1|                            1|          1|       1|               7.5|180000.0|           180|    80.0|      0.0|          0.0| 0.0|                  0.0|        1|305|                       0.0|                              0|          83273.47|          7.5|   120.0|    0.0|                     0.0|             0|[4.97297477722168...|[0.99312506755813...|       0.0|\n",
      "|           1|               0|           1|            1|               1|             5|           1|                            1|          1|       1|             7.625| 50000.0|           180|    33.0|      0.0|          0.0| 0.0|                  0.0|        1|301|                       0.0|                              0|          47657.12|        7.625|    14.0|12060.0|                     0.0|             0|[6.82305383682251...|[0.99891279113944...|       0.0|\n",
      "|           1|               0|           1|            1|               1|             6|           1|                            1|          1|       1|              6.99|174000.0|           360|    80.0|      0.0|          2.0|25.0|                749.0|        1|189|                       0.0|                              0|         173031.79|         6.99|     8.0|37980.0|                     0.0|             0|[9.16866493225097...|[0.99989575525250...|       0.0|\n",
      "|           1|               0|           1|            1|               1|             6|           1|                            1|          1|       1|               8.5| 73000.0|           360|    80.0|      0.0|          1.0|31.0|                608.0|        1|190|                       0.0|                              0|          71492.79|          8.5|    27.0|37980.0|                     0.0|             0|[4.66329050064086...|[0.99065283033996...|       0.0|\n",
      "|           1|               0|           1|            1|               1|             9|           1|                            1|          1|       1|               8.0|102000.0|           360|    80.0|      0.0|          0.0| 0.0|                  0.0|        1|606|                       0.0|                              0|         101530.01|          8.0|     8.0|16980.0|                     0.0|             0|[5.91289663314819...|[0.99730295152403...|       0.0|\n",
      "|           1|               0|           1|            1|               1|            22|           1|                            1|          2|       1|              7.75| 75000.0|           180|    47.0|      0.0|          2.0|23.0|                690.0|        1|298|                       0.0|                              0|          42430.25|         7.75|    95.0|12260.0|                     0.0|             0|[6.6304030418396,...|[0.99868210789281...|       0.0|\n",
      "|           1|               0|           1|            1|               1|            22|           1|                            1|          2|       1|              7.75| 75000.0|           180|    47.0|      0.0|          2.0|23.0|                690.0|        1|298|                       0.0|                              0|          68734.78|         7.75|    20.0|12260.0|                     0.0|             0|[7.32355213165283...|[0.99934062076499...|       0.0|\n",
      "|           1|               0|           1|            1|               3|            11|           1|                            1|          1|       1|             8.125|140000.0|           360|    90.0|      0.0|          2.0|43.0|                737.0|        1| 84|                      25.0|                              0|         138656.16|        8.125|     9.0|12100.0|                     0.0|             0|[7.49035978317260...|[0.99944186967331...|       0.0|\n",
      "+------------+----------------+------------+-------------+----------------+--------------+------------+-----------------------------+-----------+--------+------------------+--------+--------------+--------+---------+-------------+----+---------------------+---------+---+--------------------------+-------------------------------+------------------+-------------+--------+-------+------------------------+--------------+--------------------+--------------------+----------+\n",
      "only showing top 10 rows\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "modelFromDisk: ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel = xgbc_a9ea95ac6122\n",
       "results2: org.apache.spark.sql.DataFrame = [orig_channel: int, first_home_buyer: int ... 29 more fields]\n"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xgbClassificationModel.write.overwrite.save(dataRoot + \"/model/\")\n",
    "\n",
    "val modelFromDisk = XGBoostClassificationModel.load(dataRoot + \"/model/\")\n",
    "\n",
    "val (results2, _) = Benchmark.time(\"transform2\") {\n",
    "  modelFromDisk.transform(transSet)\n",
    "}\n",
    "results2.show(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "sparkSession.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spylon-kernel",
   "language": "scala",
   "name": "spylon-kernel"
  },
  "language_info": {
   "codemirror_mode": "text/x-scala",
   "file_extension": ".scala",
   "help_links": [
    {
     "text": "MetaKernel Magics",
     "url": "https://metakernel.readthedocs.io/en/latest/source/README.html"
    }
   ],
   "mimetype": "text/x-scala",
   "name": "scala",
   "pygments_lexer": "scala",
   "version": "0.4.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
